import os
from Bio import SeqIO
from matplotlib import patches
from pylab import *


def read_annotation_counts():
    """Read first-nucleotide counts per annotation from ``firstnucleotide.txt``.

    The file must start with a header line ``#annotation`` followed by one
    whitespace-separated column name per sample, each of the form
    ``<dataset>,<library>:A,C,G,T,a,c,g,t``.  Every following line is
    tab-separated: an annotation name and then, per sample, eight
    comma-separated integer counts in the order ``A,C,G,T,a,c,g,t``.

    The annotations ``sense_proximal``, ``sense_distal``, ``sense_upstream``
    and ``sense_distal_upstream`` are pooled under ``"sense"``; ``antisense``,
    ``antisense_distal``, ``prompt`` and ``antisense_distal_upstream`` are
    pooled under ``"antisense"``.  Counts of rows that end up under the same
    annotation are summed.

    MiSeq libraries whose name does not start with ``"t"`` and the HiSeq
    library ``"t01_r3"`` are left out.

    Returns:
        dict: ``counts[annotation][dataset][library]`` is a length-8 array
        holding the summed counts reordered as ``A, a, C, c, G, g, T, t``.

    Raises:
        AssertionError: if the header line does not have the expected layout.
        ValueError: if a sample column does not hold exactly eight integers.
    """
    filename = "firstnucleotide.txt"
    print("Reading", filename)
    sense_annotations = ('sense_proximal', "sense_distal",
                         "sense_upstream", 'sense_distal_upstream')
    antisense_annotations = ("antisense", "antisense_distal",
                             "prompt", "antisense_distal_upstream")
    counts = {}
    samples = []
    # The context manager closes the file even if parsing fails part-way.
    with open(filename) as handle:
        line = next(handle)
        assert line.startswith("#")
        words = line[1:].split()
        assert words[0] == 'annotation'
        for sample in words[1:]:
            dataset_library, acgt = sample.split(":")
            assert acgt == "A,C,G,T,a,c,g,t"
            dataset, library = dataset_library.split(",")
            samples.append((dataset, library))
        for line in handle:
            words = line.strip().split("\t")
            annotation = words[0]
            if annotation in sense_annotations:
                annotation = "sense"
            elif annotation in antisense_annotations:
                annotation = "antisense"
            # Reuse the existing entry: several input rows are pooled under
            # "sense" / "antisense", and re-creating the dictionary for each
            # row would discard the counts accumulated from earlier rows.
            annotation_counts = counts.setdefault(annotation, {})
            for (dataset, library), word in zip(samples, words[1:]):
                if dataset == "MiSeq" and not library.startswith("t"):
                    # Include time course samples only
                    continue
                if dataset == "HiSeq" and library == "t01_r3":
                    # Skipping HiSeq negative control library using water as input material
                    continue
                A, C, G, T, a, c, g, t = (int(number) for number in word.split(","))
                # Interleave upper- and lower-case counts per nucleotide.
                values = array([A, a, C, c, G, g, T, t])
                library_counts = annotation_counts.setdefault(dataset, {})
                if library not in library_counts:
                    library_counts[library] = zeros(8)
                library_counts[library] += values
    return counts

counts = read_annotation_counts()

# Tick label shown below the bars of each data set.
dataset_labels = {
    "MiSeq": "short capped RNAs,\npaired-end libraries",
    "HiSeq": "short capped RNAs,\nsingle-end libraries",
    "CAGE": "long capped RNAs,\nCAGE libraries",
}

# The libraries of each data set occupy consecutive unit-width slots on the
# x axis; the "unmapped" annotation is used to count the libraries per data set.
positions = []        # right-hand boundary of every data set except CAGE
label_positions = []  # midpoint of each data set's run of slots
labels = []
start = end = 0
for dataset, libraries in counts["unmapped"].items():
    end = start + len(libraries)
    label_positions.append((start + end) / 2.0)
    if dataset not in dataset_labels:
        raise Exception("Unexpected data set %s" % dataset)
    labels.append(dataset_labels[dataset])
    if dataset != "CAGE":
        positions.append(end)
    start = end

print("Removing unmapped reads")
del counts["unmapped"]

# One bar colour per stacked segment; each colour appears twice in a row.
colors = (
    'green', 'green',
    'blue', 'blue',
    'yellow', 'yellow',
    'red', 'red',
)

# Annotation categories to plot, one subplot each, in this order.
annotations = [
    "chrM",
    "rRNA",
    "Pol-II short RNA",
    "Pol-III short RNA",
    "intronic short RNA",
    "short RNA precursor",
    "histone",
    "sense",
    "antisense",
    "FANTOM5_enhancer",
    "roadmap_enhancer",
    "roadmap_dyadic",
    "novel_enhancer_CAGE",
    "novel_enhancer_HiSeq",
    "other_intergenic",
]

fig = figure(figsize=(6, 12))

# Full-figure axes with hidden spines and ticks and white tick labels; it
# carries the y-axis label.
ax = fig.add_subplot(111)
for side in ('top', 'bottom', 'left', 'right'):
    ax.spines[side].set_color('none')
ax.tick_params(labelcolor='w', top=False, bottom=False, left=False, right=False)
ax.set_ylabel("Frequency of first nucleotide [%]", fontsize=8)


letters = ("A", "a", "C", "c", "G", "g", "T", "t")

for k, annotation in enumerate(annotations):
    fig.add_subplot(5, 3, k+1)
    grid_row, grid_column = divmod(k, 3)

    # One row of percentages per library (libraries grouped by data set),
    # one column per entry of `letters`.
    percentages = []
    for libraries in counts[annotation].values():
        for values in libraries.values():
            # The tiny offset avoids a division by zero for empty libraries.
            total = sum(values) + 1.e-99
            percentages.append(100.0 * values / total)
    percentages = array(percentages).reshape(-1, len(letters))

    m = len(percentages)
    x = arange(m) + 0.5
    bottom = zeros(m)
    # Stack one bar segment per letter on top of the previous segments.
    for i, color in enumerate(colors):
        row = percentages[:, i]
        letter = letters[i]
        if letter.isupper():  # matched
            alpha = 1.0
        elif letter.islower():  # mismatched
            alpha = 0.2
        else:
            raise Exception
        bar(x, row, width=1.0, bottom=bottom, color=color, alpha=alpha)
        bottom += row

    # Dashed vertical lines at the stored separator positions.
    for position in positions:
        plot([position, position], [0, 100], 'k--', linewidth=0.5)
    title(annotation, fontsize=8, pad=2)
    xticks([])
    # Only the left-most column keeps its y tick labels.
    if grid_column == 0:
        yticks(fontsize=8)
    else:
        yticks([])
    xlim(0, m)
    ylim(0, 100)
    # Only the bottom row of the 5 x 3 grid gets the data set labels.
    if grid_row == 4:
        xticks(label_positions, labels, fontsize=8, rotation=90, horizontalalignment='center', verticalalignment='top')

subplots_adjust(bottom=0.36, top=0.98, left=0.09, right=0.98, wspace=0.32, hspace=0.30)

# (face colour, label of the opaque patch, label of the faded patch); the
# explanatory text is attached to the "T" entries only.
legend_entries = (
    ('green', 'A', 'A'),
    ('blue', 'C', 'C'),
    ('yellow', 'G', 'G'),
    ('red', 'T     matched to the genome', 'T     mismatched to the genome'),
)

# Handles alternate opaque / faded for each nucleotide: A, a, C, c, G, g, T, t.
legend_handles = []
for facecolor, matched_label, mismatched_label in legend_entries:
    legend_handles.append(
        patches.Patch(facecolor=facecolor, edgecolor='black', label=matched_label))
    legend_handles.append(
        patches.Patch(facecolor=facecolor, edgecolor='black', alpha=0.2, label=mismatched_label))

legend(handles=legend_handles, fontsize=8, bbox_to_anchor=(0.09,0.0,0.89,0.32), bbox_transform=fig.transFigure, frameon=False, ncol=4, loc='center')

# Write the same figure once per output file.
for filename in ("figure_g_enrichment_timecourse.svg",
                 "figure_g_enrichment_timecourse.png"):
    print("Saving figure to %s" % filename)
    savefig(filename)
